{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import checklist\n",
    "import spacy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the spaCy pipeline whose entity predictions are tested in the cells below\n",
    "nlp = spacy.load('en_core_web_sm')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have to define some output format. In this case, let's assume the output is just the spaCy `Doc`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_spacy(inputs):\n",
    "    \"\"\"Run every input text through `nlp.pipe` and return the resulting docs as a list.\"\"\"\n",
    "    docs = nlp.pipe(inputs)\n",
    "    return [doc for doc in docs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [],
   "source": [
    "# predict_spacy only returns predictions. The wrapped function returns a dummy\n",
    "# confidence in addition to those predictions.\n",
    "from checklist.pred_wrapper import PredictorWrapper\n",
    "predict_and_conf = PredictorWrapper.wrap_predict(predict_spacy)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "import checklist\n",
    "from checklist.editor import Editor\n",
    "from checklist.perturb import Perturb\n",
    "from checklist.test_types import MFT, INV, DIR\n",
    "from checklist.expect import Expect\n",
    "# Editor instance used below to fill sentence templates and to look up name lexicons\n",
    "editor = Editor()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since our expectations are on tokens rather than full examples, we have to write custom expectation functions.\n",
    "One alternative would be to make the output an array, but then our expectations would be on the whole prediction array rather than on specific tokens.\n",
    "\n",
    "Here is an example expectation function, where we expect specific tokens to be predicted as 'PERSON' and nothing else to be predicted as 'PERSON'."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "# This assumes that pred is a spacy Doc, and that 'meta' contains 'first_name' and 'last_name'.\n",
    "def found_people(x, pred, conf, label=None, meta=None):\n",
    "    people = set([meta['first_name'], meta['last_name']])\n",
    "    pass_ = True\n",
    "    for x in pred:\n",
    "        if x.text in people and x.ent_type_ != 'PERSON':\n",
    "            pass_ = False\n",
    "        if x.text not in people and x.ent_type_ == 'PERSON':\n",
    "            pass_ = False\n",
    "    return pass_\n",
    "# Wrap the per-example check with Expect.single; the result is passed to tests as `expect`\n",
    "expect_fn = Expect.single(found_people)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also have to write a custom printing function if we want to use `test.summary`.\n",
    "Here is one where we show the whole example with ENT_TYPEs."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_ner(x, pred, conf, label=None, meta=None):\n",
    "    return ' '.join(['%s(%s)' % (x.text, x.ent_type_) for x in pred])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting 300 examples\n",
      "Test cases:      300\n",
      "Fails (rate):    2 (0.7%)\n",
      "\n",
      "Example fails:\n",
      "I() met() with() Christopher(PERSON) Davies() last(TIME) night(TIME) .()\n",
      "----\n",
      "I() met() with() Rachel(PERSON) Davies() last(TIME) night(TIME) .()\n",
      "----\n"
     ]
    }
   ],
   "source": [
    "# Sample 300 filled-in sentences; meta=True so the examples carry the metadata\n",
    "# (first_name / last_name) that the expectation function reads.\n",
    "t = editor.template(\n",
    "    'I met with {first_name} {last_name} last night.',\n",
    "    meta=True,\n",
    "    nsamples=300,\n",
    ")\n",
    "test = MFT(**t, expect=expect_fn)\n",
    "test.run(predict_and_conf)\n",
    "# Print failures with the custom formatter so that each token shows its entity type\n",
    "test.summary(format_example_fn=format_ner)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The failure rate is pretty low. Let's see how spaCy does with Vietnamese names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 90,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting 300 examples\n",
      "Test cases:      300\n",
      "Fails (rate):    41 (13.7%)\n",
      "\n",
      "Example fails:\n",
      "I() met() with() Do() Đỗ() last(TIME) night(TIME) .()\n",
      "----\n",
      "I() met() with() John(PERSON) Long(PERSON) last(PERSON) night() .()\n",
      "----\n",
      "I() met() with() Adrian(NORP) Kelly() last(TIME) night(TIME) .()\n",
      "----\n"
     ]
    }
   ],
   "source": [
    "# Keep only the first whitespace-separated word of every lexicon entry\n",
    "first = [\n",
    "    name.split()[0]\n",
    "    for name in editor.lexicons.male_from.Vietnam + editor.lexicons.female_from.Vietnam\n",
    "]\n",
    "last = [name.split()[0] for name in editor.lexicons.last_from.Vietnam]\n",
    "# Same template as before, but with the name slots filled from the lists above\n",
    "t = editor.template(\n",
    "    'I met with {first_name} {last_name} last night.',\n",
    "    first_name=first,\n",
    "    last_name=last,\n",
    "    meta=True,\n",
    "    nsamples=300,\n",
    ")\n",
    "test = MFT(**t, expect=expect_fn)\n",
    "test.run(predict_and_conf)\n",
    "test.summary(format_example_fn=format_ner)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Not as good. Let's try Brazilian names:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 93,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicting 300 examples\n",
      "Test cases:      300\n",
      "Fails (rate):    37 (12.3%)\n",
      "\n",
      "Example fails:\n",
      "I() met() with() Andréa() Ferrari() last(TIME) night(TIME) .()\n",
      "----\n",
      "I() met() with() Claudia(ORG) Lopes(ORG) last(TIME) night(TIME) .()\n",
      "----\n",
      "I() met() with() Elisa(PERSON) dos() last(TIME) night(TIME) .()\n",
      "----\n"
     ]
    }
   ],
   "source": [
    "# Keep only the first whitespace-separated word of every lexicon entry\n",
    "first = [\n",
    "    name.split()[0]\n",
    "    for name in editor.lexicons.male_from.Brazil + editor.lexicons.female_from.Brazil\n",
    "]\n",
    "last = [name.split()[0] for name in editor.lexicons.last_from.Brazil]\n",
    "# Same template as before, but with the name slots filled from the lists above\n",
    "t = editor.template(\n",
    "    'I met with {first_name} {last_name} last night.',\n",
    "    first_name=first,\n",
    "    last_name=last,\n",
    "    meta=True,\n",
    "    nsamples=300,\n",
    ")\n",
    "test = MFT(**t, expect=expect_fn)\n",
    "test.run(predict_and_conf)\n",
    "test.summary(format_example_fn=format_ner)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "checklist",
   "language": "python",
   "name": "checklist"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
